import matplotlib.pyplot as plt

# Preview the leading training images on an 8x8 grid with axes hidden.
# NOTE(review): X_train is defined outside this view; assumes it supports
# len() and integer indexing, and that each element is an image array
# `imshow` accepts (e.g. H x W, or H x W x 3/4) — confirm against the loader.
n_rows, n_cols = 8, 8
fig, axs = plt.subplots(n_rows, n_cols, figsize=(18, 18))
# Never index past the end of the dataset: with fewer samples than grid
# tiles, the remaining tiles are left blank instead of raising IndexError.
n_shown = min(n_rows * n_cols, len(X_train))
# axs.flat walks the grid row-major, so tile k shows X_train[k]; this matches
# the original i * n_cols + j indexing.
for idx, ax in enumerate(axs.flat):
    if idx < n_shown:
        ax.imshow(X_train[idx])
    ax.axis('off')
plt.show()
